package tailnet

import (
	"context"
	"io"
	"net"
	"strconv"
	"strings"
	"sync/atomic"

	"github.com/google/uuid"
	"github.com/hashicorp/yamux"
	"storj.io/drpc/drpcmux"
	"storj.io/drpc/drpcserver"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet/proto"

	"golang.org/x/xerrors"
)

// CurrentMajor and CurrentMinor form the newest Tailnet API protocol version
// this server implements. ValidateVersion rejects any requested version newer
// than CurrentMajor.CurrentMinor.
const (
	CurrentMajor = 2
	CurrentMinor = 0
)

// SupportedMajors lists the Tailnet API major versions ValidateVersion still
// accepts. For a major version below CurrentMajor, any minor version is
// accepted as long as the major appears here. CurrentMajor is handled
// separately by ValidateVersion before this list is consulted.
var SupportedMajors = []int{2, 1}

// ValidateVersion reports whether this server can serve a client requesting the
// given Tailnet API version string ("<major>.<minor>"). It returns an error when
// the string cannot be parsed, when the requested version is newer than
// CurrentMajor.CurrentMinor, or when an older major version is not listed in
// SupportedMajors. Otherwise it returns nil.
func ValidateVersion(version string) error {
	major, minor, err := parseVersion(version)
	if err != nil {
		return err
	}

	// The request is ahead of us if its major is newer, or if its major
	// matches ours but its minor is newer.
	aheadOfServer := major > CurrentMajor ||
		(major == CurrentMajor && minor > CurrentMinor)
	if aheadOfServer {
		return xerrors.Errorf("server is at version %d.%d, behind requested version %s",
			CurrentMajor, CurrentMinor, version)
	}
	if major == CurrentMajor {
		return nil
	}

	// Older majors are accepted, with any minor, only while still listed.
	for _, supported := range SupportedMajors {
		if supported == major {
			return nil
		}
	}
	return xerrors.Errorf("version %s is no longer supported", version)
}

// parseVersion splits a "<major>.<minor>" version string into its two
// components. Each component must be a non-empty run of ASCII decimal digits.
// Anything else returns zero values and an error: a sign, whitespace, a
// missing component, or extra dot-separated parts.
func parseVersion(version string) (major int, minor int, err error) {
	parts := strings.Split(version, ".")
	if len(parts) != 2 {
		return 0, 0, xerrors.Errorf("invalid version string: %s", version)
	}
	// strconv.Atoi alone accepts a leading '+' or '-'. A negative minor such
	// as "2.-1" would then slip past ValidateVersion's "minor > CurrentMinor"
	// check, so only plain digits are allowed here.
	parseComponent := func(s string) (int, error) {
		if s == "" || strings.TrimLeft(s, "0123456789") != "" {
			return 0, strconv.ErrSyntax
		}
		// Atoi still guards against values that overflow int.
		return strconv.Atoi(s)
	}
	major, err = parseComponent(parts[0])
	if err != nil {
		return 0, 0, xerrors.Errorf("invalid major version: %s", version)
	}
	minor, err = parseComponent(parts[1])
	if err != nil {
		return 0, 0, xerrors.Errorf("invalid minor version: %s", version)
	}
	return major, minor, nil
}

// streamIDContextKey is the context key under which WithStreamID stores a
// StreamID and CoordinateTailnet reads it back. Because the type is unexported,
// other packages cannot construct the same key.
type streamIDContextKey struct{}

// StreamID identifies the caller of the CoordinateTailnet RPC. We store this
// on the context because the information is extracted at the HTTP layer for
// remote clients of the API, or set outside tailnet for local clients (e.g.
// Coderd's single_tailnet).
//
// Name, ID and Auth are passed unchanged to Coordinator.Coordinate. For remote
// clients, ServeClient sets Name to "client", ID to the client's UUID, and Auth
// to a ClientTunnelAuth for the requested agent.
type StreamID struct {
	Name string
	ID   uuid.UUID
	Auth TunnelAuth
}

// WithStreamID returns a context derived from ctx that carries streamID.
// CoordinateTailnet reads it back from the stream's context to identify the
// caller.
func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
	key := streamIDContextKey{}
	return context.WithValue(ctx, key, streamID)
}

// ClientService is a tailnet coordination service that accepts a connection and version from a
// tailnet client. It supports major versions 1 and 2 of the Tailnet API protocol.
//
// logger receives warnings about rejected versions. coordPtr is loaded on each
// served connection rather than cached. drpc is the dRPC server, built by
// NewClientService, that serves 2.x clients over a yamux session.
type ClientService struct {
	logger   slog.Logger
	coordPtr *atomic.Pointer[Coordinator]
	drpc     *drpcserver.Server
}

// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
// loaded on each processed connection.
//
// A DRPCService backed by the same pointer is registered on a fresh dRPC mux,
// and the mux is wrapped in a dRPC server. That server logs its errors at debug
// level, except io.EOF. An error is returned only if registration fails.
func NewClientService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) (*ClientService, error) {
	mux := drpcmux.New()
	if err := proto.DRPCRegisterClient(mux, NewDRPCService(logger, coordPtr)); err != nil {
		return nil, xerrors.Errorf("register DRPC service: %w", err)
	}

	// io.EOF is filtered out; presumably it is a routine connection close and
	// not worth a log line.
	logDRPCError := func(err error) {
		if xerrors.Is(err, io.EOF) {
			return
		}
		logger.Debug(context.Background(), "drpc server error", slog.Error(err))
	}

	return &ClientService{
		logger:   logger,
		coordPtr: coordPtr,
		drpc:     drpcserver.NewWithOptions(mux, drpcserver.Options{Log: logDRPCError}),
	}, nil
}

// ServeClient serves one tailnet client connection, dispatching on the major
// component of version:
//   - 1: conn is handed directly to the currently loaded Coordinator's
//     ServeClient.
//   - 2: a yamux server session is started on conn and served by the dRPC
//     server. Before that, a StreamID (Name "client", ID id, and a
//     ClientTunnelAuth for agent) is attached to ctx.
//
// Unparsable and unsupported versions are logged as warnings and returned as
// errors.
func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
	major, _, err := parseVersion(version)
	if err != nil {
		s.logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
		return err
	}

	if major == 1 {
		coord := *(s.coordPtr.Load())
		return coord.ServeClient(conn, id, agent)
	}

	if major != 2 {
		s.logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
		return xerrors.New("unsupported version")
	}

	// 2.x: multiplex dRPC streams over a yamux session on conn.
	muxConfig := yamux.DefaultConfig()
	muxConfig.LogOutput = io.Discard
	session, err := yamux.Server(conn, muxConfig)
	if err != nil {
		return xerrors.Errorf("yamux init failed: %w", err)
	}
	ctx = WithStreamID(ctx, StreamID{
		Name: "client",
		ID:   id,
		Auth: ClientTunnelAuth{AgentID: agent},
	})
	return s.drpc.Serve(ctx, session)
}

// DRPCService is the dRPC-based implementation of version 2.x of the tailnet
// API; it implements proto.DRPCClientServer.
//
// coordPtr is loaded afresh on each CoordinateTailnet call. logger is the base
// logger for per-stream loggers.
type DRPCService struct {
	coordPtr *atomic.Pointer[Coordinator]
	logger   slog.Logger
}

// NewDRPCService returns a DRPCService that logs through logger and loads the
// Coordinator from coordPtr on each CoordinateTailnet call.
func NewDRPCService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) *DRPCService {
	svc := new(DRPCService)
	svc.logger = logger
	svc.coordPtr = coordPtr
	return svc
}

// StreamDERPMaps is not implemented yet; every call returns an
// "unimplemented" error.
func (*DRPCService) StreamDERPMaps(*proto.StreamDERPMapsRequest, proto.DRPCClient_StreamDERPMapsStream) error {
	// TODO integrate with Dean's PR implementation
	errUnimplemented := xerrors.New("unimplemented")
	return errUnimplemented
}

// CoordinateTailnet implements proto.DRPCClientServer.
//
// It requires a StreamID on the stream's context, attached via WithStreamID as
// ServeClient does. Without one, the stream is closed and an error returned.
// Otherwise it opens a Coordinate session on the currently loaded Coordinator.
// It then relays requests and responses between that session and the stream
// until the relay ends, and returns nil.
func (s *DRPCService) CoordinateTailnet(stream proto.DRPCClient_CoordinateTailnetStream) error {
	ctx := stream.Context()
	streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
	if !ok {
		// Best-effort close; the missing StreamID is the error worth reporting.
		_ = stream.Close()
		return xerrors.New("no Stream ID")
	}
	// Log the peer's UUID under peer_id. Passing the whole StreamID would dump
	// Name and Auth under a misleading key; Name is logged separately.
	logger := s.logger.With(slog.F("peer_id", streamID.ID), slog.F("name", streamID.Name))
	logger.Debug(ctx, "starting tailnet Coordinate")
	// The pointer is loaded per call rather than cached on the service.
	coord := *(s.coordPtr.Load())
	reqs, resps := coord.Coordinate(ctx, streamID.ID, streamID.Name, streamID.Auth)
	c := communicator{
		logger: logger,
		stream: stream,
		reqs:   reqs,
		resps:  resps,
	}
	c.communicate()
	return nil
}

// communicator relays messages for a single CoordinateTailnet call between the
// dRPC stream and a Coordinator session. Requests received on stream are
// forwarded into reqs (see loopReq). Responses read from resps are sent back on
// stream (see loopResp).
type communicator struct {
	logger slog.Logger
	stream proto.DRPCClient_CoordinateTailnetStream
	reqs   chan<- *proto.CoordinateRequest
	resps  <-chan *proto.CoordinateResponse
}

// communicate runs loopReq (stream -> reqs) on a background goroutine and
// loopResp (resps -> stream) on the calling goroutine. It returns as soon as
// loopResp finishes. Nothing waits for loopReq, which exits once stream.Recv
// or SendCtx returns an error.
func (c communicator) communicate() {
	go func() {
		c.loopReq()
	}()
	c.loopResp()
}

// loopReq forwards CoordinateRequests received on the dRPC stream to the
// coordinator through reqs, until stream.Recv or SendCtx fails. reqs is closed
// on return, signalling that no further requests will be sent.
func (c communicator) loopReq() {
	ctx := c.stream.Context()
	defer close(c.reqs)
	for {
		req, err := c.stream.Recv()
		if err != nil {
			c.logger.Debug(ctx, "error receiving requests from DRPC stream", slog.Error(err))
			return
		}
		err = SendCtx(ctx, c.reqs, req)
		if err != nil {
			// Log the error SendCtx actually returned rather than re-reading
			// ctx.Err(), matching the Recv branch above.
			c.logger.Debug(ctx, "context done while sending coordinate request", slog.Error(err))
			return
		}
	}
}

// loopResp sends CoordinateResponses read from resps to the dRPC stream, until
// RecvCtx or stream.Send fails. It then closes the stream; a close failure is
// only logged.
func (c communicator) loopResp() {
	ctx := c.stream.Context()
	closeStream := func() {
		if closeErr := c.stream.Close(); closeErr != nil {
			c.logger.Debug(ctx, "loopResp hit error closing stream", slog.Error(closeErr))
		}
	}
	defer closeStream()

	for {
		resp, recvErr := RecvCtx(ctx, c.resps)
		if recvErr != nil {
			c.logger.Debug(ctx, "loopResp failed to get response", slog.Error(recvErr))
			return
		}
		if sendErr := c.stream.Send(resp); sendErr != nil {
			c.logger.Debug(ctx, "loopResp failed to send response to DRPC stream", slog.Error(sendErr))
			return
		}
	}
}
